"""
Phase 3: Panel Estimation — Core Japanification Regressions
=============================================================
1. Baseline: Japan_index = γ₁Z₁ + γ₂Z₂ + γ₃Z₃ + β'X + u
2. KAOPEN interaction model
3. Demographic component regressions (Z → growth, Z → inflation, Z → rates)
4. OADR specifications (alternative to Z polynomials)
5. Age-group coefficient recovery from Z polynomials
6. Cross-validation: Z → CA/GDP (should match Project 1)

Input:  japanification/data/processed/japan_panel_indexed.csv
Output: japanification/output/tables/phase3_*.csv
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

# Base directory from which every path below is derived.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
# Sub-directory holding the Japanification data and outputs.
JAPAN_DIR = PROJECT_DIR.joinpath("japanification")
# Where the processed input panel is read from.
PROCESSED_DIR = JAPAN_DIR.joinpath("data", "processed")
# Where the regression output tables are written.
TABLE_DIR = JAPAN_DIR.joinpath("output", "tables")

# Import PanelGLS from multilateral project
sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS


def fit_and_report(y, X, entity_ids, time_ids, feature_names, label):
    """Estimate one PanelGLS specification, print it, and tabulate it.

    Args:
        y: Dependent-variable values, forwarded to ``PanelGLS.fit``.
        X: Regressor values, forwarded to ``PanelGLS.fit``.
        entity_ids: Panel entity identifier per observation.
        time_ids: Panel time identifier per observation.
        feature_names: Names passed to ``summary`` and ``to_dataframe``.
        label: Title printed above the summary and stored in the ``model``
            column of the returned table.

    Returns:
        tuple: ``(model, result_df)`` — the fitted ``PanelGLS`` instance and
        the frame from ``to_dataframe`` with an added ``model`` column.
    """
    estimator = PanelGLS()
    # The fit runs before the banner is printed, as in the original order.
    estimator.fit(y, X, entity_ids, time_ids)

    rule = '=' * 70
    print(f"\n{rule}\n  {label}\n{rule}")
    estimator.summary(feature_names=feature_names)

    table = estimator.to_dataframe(feature_names=feature_names)
    table['model'] = label
    return estimator, table


# Approximate age range covered by each ``d_n_<k>`` column, keyed by ``k``.
_AGE_BIN_LABELS = {
    1: '0-4', 2: '5-9', 3: '10-14', 4: '15-19', 5: '20-24',
    6: '25-29', 7: '30-34', 8: '35-39', 9: '40-44', 10: '45-49',
    11: '50-54', 12: '55-59', 13: '60-64', 14: '65-69', 15: '70-74',
    16: '75-79', 17: '80+',
}


def _age_bin_label(bin_name):
    """Return the age-range label for a column named ``d_n_<k>``."""
    # Key on the number embedded in the column name, not on list position,
    # so labels stay correct when some bins are absent from the data.
    bin_number = int(bin_name.rsplit('_', 1)[-1])
    return _AGE_BIN_LABELS.get(bin_number, f'bin_{bin_number}')


def _complete_cases(df, dep, regressors):
    """Return the rows of *df* with no missing value in *dep* or *regressors*."""
    return df.dropna(subset=[dep] + regressors).copy()


def _fit_sample(sample, dep, regressors, label):
    """Fit one specification on *sample* through ``fit_and_report``."""
    return fit_and_report(
        sample[dep].values, sample[regressors].values,
        sample['iso3'].values, sample['year'].values,
        regressors, label,
    )


def main():
    """Run every Phase 3 specification and save the coefficient tables.

    Reads ``japan_panel_indexed.csv`` from ``PROCESSED_DIR``, estimates each
    specification with ``fit_and_report`` on the rows that are complete for
    that specification, and writes the stacked coefficient table to
    ``phase3_regression_results.csv`` in ``TABLE_DIR``.  When the age-bin
    model is estimated, its per-bin coefficients are also written to
    ``phase3_age_coefficients.csv``.

    Returns:
        pandas.DataFrame: The per-model result frames concatenated row-wise.
    """
    rule = "=" * 70
    print(rule)
    print("PHASE 3: Panel Estimation — Japanification Regressions")
    print(rule)

    df = pd.read_csv(PROCESSED_DIR / "japan_panel_indexed.csv")
    print(f"Loaded: {len(df):,} obs, {df['iso3'].nunique()} countries")

    # DataFrame.to_csv does not create missing directories, so make sure the
    # output folder exists before spending time on estimation.
    TABLE_DIR.mkdir(parents=True, exist_ok=True)

    all_results = []

    # =================================================================
    # 1. Baseline: Japan_index = γZ + βX + u
    # =================================================================
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag']
    all_vars = demo_vars + controls

    # Use 2-component index for maximum coverage
    dep_var = 'japan_index_2c'

    est = _complete_cases(df, dep_var, all_vars)
    print(f"\nBaseline sample: {len(est):,} obs, {est['iso3'].nunique()} countries")
    _, r1 = _fit_sample(est, dep_var, all_vars,
                        "Model 1: Baseline (2-component index)")
    all_results.append(r1)

    # Also with 3-component index (rate subsample)
    dep_var_3c = 'japan_index_3c'
    est3 = _complete_cases(df, dep_var_3c, all_vars)
    print(f"\n3-component sample: {len(est3):,} obs, {est3['iso3'].nunique()} countries")
    _, r1b = _fit_sample(est3, dep_var_3c, all_vars,
                         "Model 1b: Baseline (3-component index)")
    all_results.append(r1b)

    # =================================================================
    # 2. KAOPEN interaction model
    # =================================================================
    interaction_vars = ['Z_1_x_kaopen', 'Z_2_x_kaopen', 'Z_3_x_kaopen']
    avail_interactions = [v for v in interaction_vars if v in df.columns]

    # Estimated with whichever interaction columns exist; skipped if none do.
    if avail_interactions:
        all_vars_int = all_vars + avail_interactions
        est_int = _complete_cases(df, dep_var, all_vars_int)
        _, r2 = _fit_sample(est_int, dep_var, all_vars_int,
                            "Model 2: KAOPEN Interactions")
        all_results.append(r2)

    # =================================================================
    # 3. Component regressions: Z → each component separately
    # =================================================================
    component_deps = {
        'Z → Growth': 'rgdp_growth',
        'Z → Inflation': 'inflation_japan',
        'Z → Rate': 'rate_japan',
    }

    for label, dep in component_deps.items():
        comp_est = _complete_cases(df, dep, all_vars)
        if len(comp_est) < 100:
            print(f"\n  Skipping {label}: insufficient obs ({len(comp_est)})")
            continue

        _, rc = _fit_sample(comp_est, dep, all_vars, f"Model 3: {label}")
        all_results.append(rc)

    # =================================================================
    # 4. OADR specifications (alternative to Z polynomials)
    # =================================================================
    if 'old_dep' in df.columns:
        df['old_dep_sq'] = df['old_dep'] ** 2
        oadr_vars = ['old_dep', 'old_dep_sq'] + controls
        est_oadr = _complete_cases(df, dep_var, oadr_vars)
        _, r4 = _fit_sample(est_oadr, dep_var, oadr_vars,
                            "Model 4: OADR Quadratic")
        all_results.append(r4)

        # OADR with youth dependency too
        if 'youth_dep' in df.columns:
            dep_vars_full = ['old_dep', 'old_dep_sq', 'youth_dep'] + controls
            est_dep = _complete_cases(df, dep_var, dep_vars_full)
            _, r4b = _fit_sample(est_dep, dep_var, dep_vars_full,
                                 "Model 4b: OADR + Youth Dependency")
            all_results.append(r4b)

    # =================================================================
    # 5. Age-group coefficient recovery
    # =================================================================
    print(f"\n{rule}")
    print("  Age-Group Coefficient Recovery")
    print(rule)

    # Use d_n_1 through d_n_17 age bins directly
    age_bins = [f'd_n_{i}' for i in range(1, 18)]
    avail_bins = [b for b in age_bins if b in df.columns]

    if len(avail_bins) >= 10:
        age_vars = avail_bins + controls
        est_age = _complete_cases(df, dep_var, age_vars)

        if len(est_age) >= 200:
            m5, r5 = _fit_sample(est_age, dep_var, age_vars,
                                 "Model 5: Age-Bin Coefficients")
            all_results.append(r5)

            # Extract age-bin coefficients for plotting.
            # NOTE(review): assumes m5.beta / m5.se / m5.pvalues follow the
            # column order of the regressors with no leading intercept, so
            # the first len(avail_bins) entries are the age bins — confirm
            # against PanelGLS.
            n_bins = len(avail_bins)
            age_coefs = pd.DataFrame({
                'age_bin': avail_bins,
                'coefficient': m5.beta[:n_bins],
                'std_error': m5.se[:n_bins],
                'p_value': m5.pvalues[:n_bins],
            })
            # Label each row from its own bin number so a missing d_n_<k>
            # column cannot shift the labels of the bins that follow it.
            age_coefs['age_group'] = [_age_bin_label(b) for b in avail_bins]
            age_coefs.to_csv(TABLE_DIR / "phase3_age_coefficients.csv", index=False)

            print("\n  Age-Group Japanification Profile:")
            print(age_coefs[['age_group', 'coefficient', 'std_error', 'p_value']]
                  .to_string(index=False, float_format='%.4f'))

    # =================================================================
    # 6. Cross-validation: Z → CA/GDP (should match Project 1)
    # =================================================================
    if 'ca_gdp' in df.columns:
        ca_est = _complete_cases(df, 'ca_gdp', all_vars)
        if len(ca_est) >= 200:
            _, r_ca = _fit_sample(
                ca_est, 'ca_gdp', all_vars,
                "Cross-check: Z → CA/GDP (should match Project 1)"
            )
            all_results.append(r_ca)

    # =================================================================
    # Save all results
    # =================================================================
    results_df = pd.concat(all_results, ignore_index=True)
    results_df.to_csv(TABLE_DIR / "phase3_regression_results.csv", index=False)
    print(f"\n{rule}")
    print(f"Saved regression results: {TABLE_DIR / 'phase3_regression_results.csv'}")
    print(f"  {len(results_df)} coefficient rows across {results_df['model'].nunique()} models")

    return results_df


# Script entry point: run the estimation pipeline only when this file is
# executed directly, not when it is imported. The DataFrame returned by
# main() is bound to `results` at module scope.
if __name__ == "__main__":
    results = main()
